package util

import (
	"errors"
	"golang.org/x/sync/singleflight"
	"sync"
	"time"
)

// MemCache is an in-memory key/value store with optional per-key expiry.
type MemCache struct {
	// data maps each key to its cached value.
	data       map[string]any
	// keys2clear maps a key to its expiry as a Unix timestamp in seconds.
	// Keys without an entry here are treated as never expiring.
	keys2clear map[string]int64
	// sfg is the singleflight group used by GetOrSingleDo, keyed by cache key.
	sfg        *singleflight.Group
	// gcTd is how long the background sweeper sleeps between passes.
	gcTd       time.Duration
	// gcNum caps how many keys2clear entries a single sweep pass examines.
	gcNum      uint64
	// lock is taken by the methods that read or mutate the maps and GC settings.
	lock       sync.RWMutex
}

// Sentinel errors returned by MemCache methods. Compare with errors.Is.
var (
	// ErrMemCacheKeyExists reports that the key is already present.
	ErrMemCacheKeyExists = errors.New(`key已存在`)
	// ErrMemCacheKeyNotExist reports that the key is not in the cache.
	ErrMemCacheKeyNotExist = errors.New(`key不存在`)
	// ErrMemCacheKeyExpired reports that the key was found but had expired.
	ErrMemCacheKeyExpired = errors.New(`key已过期`)
)

// NewMemCache returns an empty cache and starts the background goroutine that
// evicts expired keys.
//
// The sweeper sleeps for gcTd (one minute by default) between passes and
// examines at most gcNum entries (100 by default) of keys2clear per pass; both
// values can be changed with GC.
//
// NOTE: the goroutine has no stop signal, so it keeps running, and keeps the
// cache reachable, until the process exits.
func NewMemCache() *MemCache {
	ss := &MemCache{
		data:       make(map[string]any),
		keys2clear: make(map[string]int64),
		gcTd:       time.Minute,
		gcNum:      100,
		sfg:        new(singleflight.Group),
	}
	go func(ss *MemCache) {
		for {
			// GC may rewrite gcTd at any time, so it must be read under the lock.
			ss.lock.RLock()
			td := ss.gcTd
			ss.lock.RUnlock()
			if td <= 0 {
				// A non-positive duration makes Sleep return immediately,
				// which would turn this loop into a busy spin.
				td = time.Minute
			}
			time.Sleep(td)

			// Block for the write lock instead of using TryLock: with TryLock
			// a busy cache could skip every pass and never evict anything.
			// The maps are only inspected once the lock is held.
			ss.lock.Lock()
			now := time.Now().Unix()
			var scanned uint64
			for k, expAt := range ss.keys2clear {
				if scanned == ss.gcNum {
					break
				}
				if expAt > 0 && expAt < now {
					delete(ss.keys2clear, k)
					delete(ss.data, k)
				}
				scanned++
			}
			ss.lock.Unlock()
		}
	}(ss)
	return ss
}

// GC sets the interval between expired-key sweeps and the maximum number of
// expiry entries examined per sweep.
//
// A non-positive td falls back to one minute, and a num of 0 falls back to
// 100. Negative durations are rejected because time.Sleep returns immediately
// for them, which would make a sleeping loop spin.
func (s *MemCache) GC(td time.Duration, num uint64) {
	if td <= 0 {
		td = time.Minute
	}
	if num == 0 {
		num = 100
	}

	s.lock.Lock()
	defer s.lock.Unlock()
	s.gcTd = td
	s.gcNum = num
}

// Set stores v under k, replacing any previous value.
//
// exp is a lifetime in seconds counted from now. When exp is zero or negative
// the key is stored without an expiry, and any expiry previously recorded for
// it is removed.
func (s *MemCache) Set(k string, v any, exp int64) {
	s.lock.Lock()
	defer s.lock.Unlock()

	switch {
	case exp <= 0:
		delete(s.keys2clear, k)
	default:
		expireAt := time.Now().Unix() + exp
		s.keys2clear[k] = expireAt
	}
	s.data[k] = v
}

// SetNX stores v under k only if k is absent or its entry has already expired.
//
// exp is a lifetime in seconds counted from now; zero or negative means the
// key never expires. It returns nil when the value was stored and
// ErrMemCacheKeyExists when a live entry is already present.
//
// The existence check and the store happen under a single write lock, so of
// several concurrent SetNX calls for the same key only one can succeed.
func (s *MemCache) SetNX(k string, v any, exp int64) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	now := time.Now().Unix()
	if _, ok := s.data[k]; ok {
		// An entry past its expiry counts as absent and may be overwritten.
		if expAt := s.keys2clear[k]; expAt <= 0 || expAt >= now {
			return ErrMemCacheKeyExists
		}
	}

	if exp > 0 {
		s.keys2clear[k] = now + exp
	} else {
		delete(s.keys2clear, k)
	}
	s.data[k] = v
	return nil
}

// ExpireAt reports when k expires, as a Unix timestamp in seconds.
//
// It returns 0 when no expiry is recorded for k, which is the case both for
// keys that never expire and for keys that are not in the cache.
func (s *MemCache) ExpireAt(k string) int64 {
	s.lock.RLock()
	defer s.lock.RUnlock()
	return s.keys2clear[k]
}

// SetExpire changes the lifetime of an existing key.
//
// exp is a lifetime in seconds counted from now; zero or negative removes the
// expiry so the key never expires. It returns ErrMemCacheKeyNotExist when k is
// absent and ErrMemCacheKeyExpired when k had already expired (the stale entry
// is removed in that case). The check and the update run under one write lock.
func (s *MemCache) SetExpire(k string, exp int64) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	if _, err := s.lookupLocked(k); err != nil {
		return err
	}
	if exp > 0 {
		s.keys2clear[k] = time.Now().Unix() + exp
	} else {
		delete(s.keys2clear, k)
	}
	return nil
}

// Get returns the value stored under k.
//
// It returns (nil, ErrMemCacheKeyNotExist) when k is absent. When k is present
// but past its expiry, the entry is removed and the stale value is returned
// together with ErrMemCacheKeyExpired.
func (s *MemCache) Get(k string) (any, error) {
	// Fast path: live or missing keys only need the read lock.
	s.lock.RLock()
	v, ok := s.data[k]
	expAt := s.keys2clear[k]
	s.lock.RUnlock()

	if !ok {
		return nil, ErrMemCacheKeyNotExist
	}
	if expAt <= 0 || expAt >= time.Now().Unix() {
		return v, nil
	}

	// Slow path: evicting mutates the maps, which needs the write lock. The
	// entry is looked up again because it may have been replaced or removed
	// between releasing the read lock and acquiring the write lock.
	s.lock.Lock()
	defer s.lock.Unlock()
	return s.lookupLocked(k)
}

// lookupLocked fetches k and evicts it if it has expired.
// The caller must hold s.lock for writing.
func (s *MemCache) lookupLocked(k string) (any, error) {
	v, ok := s.data[k]
	if !ok {
		return nil, ErrMemCacheKeyNotExist
	}
	if expAt := s.keys2clear[k]; expAt > 0 && expAt < time.Now().Unix() {
		delete(s.data, k)
		delete(s.keys2clear, k)
		return v, ErrMemCacheKeyExpired
	}
	return v, nil
}

// GetOrSingleDo returns the cached value for k, falling back to fn when Get
// reports any error (missing or expired key).
//
// The lookup and the fallback both run inside the singleflight group keyed by
// k, so concurrent callers for the same key share one execution. fn receives
// the key and the cache itself; its result is returned as-is and is not stored
// by this method. When fn is nil the call is a plain Get.
func (s *MemCache) GetOrSingleDo(k string, fn func(k string, mc *MemCache) (any, error)) (any, error) {
	if fn == nil {
		return s.Get(k)
	}

	load := func() (any, error) {
		if cached, getErr := s.Get(k); getErr == nil {
			return cached, nil
		}
		return fn(k, s)
	}

	// The third result (whether the value was shared) is not needed here.
	val, err, _ := s.sfg.Do(k, load)
	return val, err
}

// Del removes k and any expiry recorded for it. Deleting a key that is not
// present is a no-op.
func (s *MemCache) Del(k string) {
	s.lock.Lock()
	delete(s.keys2clear, k)
	delete(s.data, k)
	s.lock.Unlock()
}

// Clear removes every key and every recorded expiry, leaving the cache empty.
func (s *MemCache) Clear() {
	s.lock.Lock()
	clear(s.keys2clear)
	clear(s.data)
	s.lock.Unlock()
}
